Amazon SageMaker上で学習させたMXNetのモデルで、自分が書いた数字を分類させてみた
はじめに
Amazon SageMakerは、機械学習のモデルの学習からホストまで簡単に行うことができるフルマネージドサービスです。
今回は、Amazon SageMakerを使ってMXNetの分類モデルを学習させるサンプルノートブックをやってみたので、その内容をご紹介します。
概要
MNISTのデータセットを教師データとして、MXNetで分類モデルを学習させます。作成したモデルで自らが書いた数字を分類します。
学習処理とモデルのホスティングはSageMaker Python SDKを介してAmazon SageMaker上で行います。
- MNISTのデータセットは、手書き数字画像とラベルデータを含んでいます。データ分析のチュートリアルでよく使われるデータセットの一つです。
基本的にはこのノートブックに沿って進めますが、一部改変しています。では、実際にやってみましょう。
やってみた
準備
ノートブックの作成
SageMakerのノートブックインスタンスを立ち上げて表示されるjupyterのトップページのタブから
SageMaker Examples
↓
SageMaker Python SDK
↓
mxnet_mnist.ipynb
↓
use
でサンプルのノートブックとスクリプトをコピーして、開きます。
ノートブックインスタンスの作成についてはこちらをご参照ください。
環境変数とロールの確認
学習データ等を保存するS3の場所の指定と、学習やエンドポイントを立ち上げる際に使用するIAMロールの取得を行います。
from sagemaker import get_execution_role # S3のバケット名と保存先の接頭辞 bucket_name = 'hoge-bucket' s3_prefix = 'hoge-prefix' # カスタムコードの保存場所 custom_code_upload_location = 's3://{}/{}/customcode/mxnet'.format(bucket_name, s3_prefix) # モデルの学習結果の保存場所 model_artifacts_location = 's3://{}/{}/artifacts'.format(bucket_name, s3_prefix) # 学習時やエンドポイントの展開時に使用するIAMロール # 今回はノートブックインスタンスに使用しているIAMロールから取得してくる role = get_execution_role()
学習
entry_point(学習用スクリプト)について
スクリプトファイルmnist.pyを学習時のentry_pointとして設定します。このスクリプトの中には、train
という関数が定義されている必要があります。この関数は学習時に呼び出されるため、MXNetによる学習処理を定義します。
※詳細はSDKのGitHubでの説明をご参照ください。
今回、entry_pointとして使用するスクリプトは以下の通りです。
import logging import gzip import mxnet as mx import numpy as np import os import struct # 指定したデータを読み込む def load_data(path): with gzip.open(find_file(path, "labels.gz")) as flbl: struct.unpack(">II", flbl.read(8)) labels = np.fromstring(flbl.read(), dtype=np.int8) with gzip.open(find_file(path, "images.gz")) as fimg: _, _, rows, cols = struct.unpack(">IIII", fimg.read(16)) images = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(labels), rows, cols) images = images.reshape(images.shape[0], 1, 28, 28).astype(np.float32) / 255 return labels, images # 指定したファイルを探す def find_file(root_path, file_name): for root, dirs, files in os.walk(root_path): if file_name in files: return os.path.join(root, file_name) # 分類モデルの定義 # モデル:IN→平坦化→全結合層(64ノード)→活性化関数(ReLu)→全結合層(64ノード)→活性化関数(ReLu)→全結合層(64ノード)→ソフトマックス関数→OUT def build_graph(): data = mx.sym.var('data') data = mx.sym.flatten(data=data) fc1 = mx.sym.FullyConnected(data=data, num_hidden=128) act1 = mx.sym.Activation(data=fc1, act_type="relu") fc2 = mx.sym.FullyConnected(data=act1, num_hidden=64) act2 = mx.sym.Activation(data=fc2, act_type="relu") fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10) return mx.sym.SoftmaxOutput(data=fc3, name='softmax') # MXNetを使った学習処理 def train(current_host, channel_input_dirs, hyperparameters, hosts, num_cpus, num_gpus): (train_labels, train_images) = load_data(os.path.join(channel_input_dirs['train'])) (test_labels, test_images) = load_data(os.path.join(channel_input_dirs['test'])) # Alternatively to splitting in memory, the data could be pre-split in S3 and use ShardedByS3Key # to do parallel training. shard_size = len(train_images) // len(hosts) for i, host in enumerate(hosts): if host == current_host: start = shard_size * i end = start + shard_size break batch_size = 100 train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end], batch_size, shuffle=True) val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size) logging.getLogger().setLevel(logging.DEBUG) kvstore = 'local' if len(hosts) == 1 else 'dist_sync' mlp_model = mx.mod.Module( symbol=build_graph(), context=get_train_context(num_cpus, num_gpus)) mlp_model.fit(train_iter, eval_data=val_iter, kvstore=kvstore, optimizer='sgd', optimizer_params={'learning_rate': float(hyperparameters.get("learning_rate", 0.1))}, eval_metric='acc', batch_end_callback=mx.callback.Speedometer(batch_size, 100), num_epoch=25) return mlp_model # 学習を行うのはGPUかCPUか def get_train_context(num_cpus, num_gpus): if num_gpus > 0: return mx.gpu() return mx.cpu()
パラメータの設定
学習に向けて、パラメータを設定します。
from sagemaker.mxnet import MXNet mnist_estimator = MXNet(entry_point='mnist.py',# 学習用スクリプトファイルを指定(スクリプトファイルが複数の場合はsource_dirで指定出来ます) role=role, # 学習やエンドポイントの作成に使用するIAMロール名 output_path=model_artifacts_location, # モデルアーティファクトの出力先 code_location=custom_code_upload_location, # スクリプトを保存する場所 train_instance_count=1, # 学習時に使用するインスタンス数 train_instance_type='ml.m4.xlarge', # 学習時に使用するインスタンスタイプ framework_version='1.2.1', # 使用するMXNetのバージョン hyperparameters={ 'learning_rate': 0.1 # 学習率をハイパーパラメータとして設定 })
学習の実行
学習用データとテスト用データを登録し、学習を開始します。
SageMakerがインスタンスを立ち上げて、学習処理を実行し、学習終了後にインスタンスを自動的に終了させます。学習状態は随時ログが出て来るので、追うことができます。
%%time import boto3 # S3上で公開されているデータを学習用、テスト用として設定します。 region = boto3.Session().region_name train_data_location = 's3://sagemaker-sample-data-{}/mxnet/mnist/train'.format(region) test_data_location = 's3://sagemaker-sample-data-{}/mxnet/mnist/test'.format(region) # 学習の実行 mnist_estimator.fit({'train': train_data_location, 'test': test_data_location})
最終エポックの検証精度をみて見ると、97.5%とかなり高い数字が出ています。
モデルの展開
エンドポイントを作成し、先ほど学習させたモデルをエンドポイントに展開します。 今回は少し触るだけなので、ノートブックに書かれている内容から変更してml.t2.mediumインスタンスを使用します。
%%time predictor = mnist_estimator.deploy(initial_instance_count=1, instance_type='ml.t2.medium')
推論
では、先ほど作成したモデルを使用して、実際に書いた数字を分類してみます。
まずは、数字を書く場所を作成します。 input.htmlにはキャンバスとボタンを描画するHTMLと、キャンバス上に文字を書けるようにするスクリプト(JavaScript)が含まれています。
from IPython.display import HTML HTML(open("input.html").read())
では、このキャンバス上に「0」を書いてみます。
書いた数字はピクセルごとに0か1が格納されたリスト形式(1*28*28
)として、変数data
に代入されています。
次に、書いた数字のピクセルデータを分類モデルをホストしているエンドポイントに投げて、推論結果を受け取ります。
response = predictor.predict(data) print('Raw prediction result:') print(response) labeled_predictions = list(zip(range(10), response[0])) print('Labeled predictions: ') print(labeled_predictions) labeled_predictions.sort(key=lambda label_and_prob: 1.0 - label_and_prob[1]) print('Most likely answer: {}'.format(labeled_predictions[0]))
一番下のMost like answer
に最も予測確率の高い数字と、その数字に対する予測確率が表示されています。
正しく0と分類しています。
他の数字も試してみます。
正解!
次は1を大げさに書いてみます。
ダメでした。検証精度はかなり高かったですが、汚すぎる字は難しいようです。学習データにノイズを加えたり、歪ましたりみたいなことが必要なのかもしれません。
エンドポイントの削除
余計な費用がかからないように、エンドポイントを削除します。
import sagemaker sagemaker.Session().delete_endpoint(predictor.endpoint)
さいごに
今回はAmazon SageMaker上でのMXNetを使った手書き数字の分類方法について紹介しました。SageMaker Python SDKを利用することで、MXNetを使った学習と推論が簡単に行うことができました。
HTMLで入力を作ってそのまま推論エンドポイントに投げて確認するというのは、Amazon SageMakerのノートブックでの探索的アプローチのしやすさを感じることができますね。
これからAmazon SageMakerを触ってみようと思っている方の参考になれば幸いです。 最後までお読みいただき、ありがとうございましたー!